{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "import os\n",
    "import mne\n",
    "import pywt\n",
    "import keras\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from MyEDFImports import import_ecg, get_edf_filenames, import_eeg, load_all_labels, load_all_data, stages_names, stages_names_3_outputs"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "efds = get_edf_filenames()\n",
    "tmin, tmax = 100, 101\n",
    "ecg = import_ecg(name=efds[0], tmin=tmin, tmax=tmax)\n",
    "one_eeg = import_eeg(name=efds[0], tmin=tmin, tmax=tmax)\n",
    "\n",
    "# printing info\n",
    "print(one_eeg.info)\n",
    "print(ecg.info)\n",
    "# plotting ECG\n",
    "y = ecg[0][0].T\n",
    "x = ecg[0][1]\n",
    "plt.figure(figsize=(20, 6))\n",
    "plt.step(x, y, )\n",
    "plt.ticklabel_format(style='sci', axis='y', scilimits=(0, 0))\n",
    "plt.legend(['ECG'])\n",
    "plt.xlabel('time [s]')\n",
    "plt.ylabel('voltage [V]')\n",
    "plt.title(f\"ECG of CN223100.edf from {tmin}s to {tmax}s\")\n",
    "plt.show()\n",
    "# Plotting EEG\n",
    "y = one_eeg[0][0].T\n",
    "x = one_eeg[0][1]\n",
    "plt.figure(figsize=(20, 6))\n",
    "plt.step(x, y)\n",
    "ax = plt.gca()\n",
    "ax.invert_yaxis()\n",
    "plt.ticklabel_format(style='sci', axis='y', scilimits=(0, 0))\n",
    "plt.legend(['EEG'])\n",
    "plt.xlabel('time [s]')\n",
    "plt.ylabel('voltage [V]')\n",
    "plt.title(f\"FP1A2 EEG of CN223100.edf from {tmin}s to {tmax}s\")\n",
    "plt.show()\n"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "data = load_all_data()\n",
    "labels = load_all_labels()"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "from collections import Counter\n",
    "\n",
    "counted_labels = Counter(labels)\n",
    "counted_labels"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "def three_stages_transform(n: int):\n",
    "    if n == 0:\n",
    "        return 0\n",
    "    if n == 5:\n",
    "        return 2\n",
    "    return 1\n",
    "\n",
    "\n",
    "labels_3_stages = np.array(list(map(three_stages_transform, labels)))\n",
    "counted_labels_3_stages = Counter(labels_3_stages)"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "scales = range(1, 128)\n",
    "waveletname = 'morl'"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "for key in counted_labels_3_stages.keys():\n",
    "    random_ind = np.random.choice(np.where(labels == key)[0])\n",
    "    random_slice = data[random_ind]\n",
    "    coef, freq = pywt.cwt(random_slice, scales, waveletname)\n",
    "    plt.imshow(coef, extent=[-1, 1, 1, 128], cmap='PRGn', aspect='auto',\n",
    "               vmax=abs(coef).max(), vmin=-abs(coef).max(), label=key)\n",
    "    plt.title(stages_names_3_outputs[key])\n",
    "    plt.show()"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pywt\n",
    "\n",
    "# choose random for slice for each key\n",
    "for key in counted_labels.keys():\n",
    "    random_ind = np.random.choice(np.where(labels == key)[0])\n",
    "    random_slice = data[random_ind]\n",
    "    coef, freq = pywt.cwt(random_slice, scales, waveletname)\n",
    "    plt.imshow(coef, extent=[-1, 1, 1, 128], cmap='PRGn', aspect='auto',\n",
    "               vmax=abs(coef).max(), vmin=-abs(coef).max())\n",
    "    plt.title(stages_names[key])\n",
    "    plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
